#!/usr/bin/env python
"""
# Author: Xiong Lei
# Created Time : Thu 16 Jul 2020 07:24:49 PM CST

# File Name: plot.py
# Description:

"""
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

            
            
def embedding(
        adata,
        color='celltype',
        color_map=None,
        groupby='batch',
        groups=None,
        cond2=None,
        v2=None,
        save=None,
        legend_loc='right margin',
        legend_fontsize=None,
        legend_fontweight='bold',
        sep='_',
        basis='X_umap',
        size=10,
        show=True,
    ):
    """
    Plot one embedding per group, with all remaining cells drawn as background

    For every value of `groupby` a temporary column ``adata.obs['tmp']`` is
    created that keeps the `color` label of the selected cells and an empty
    string for every other cell. The column (and ``adata.uns['tmp_colors']``)
    is removed again after each plot, so an existing ``'tmp'`` column in
    ``adata.obs`` is overwritten and dropped.

    Parameters
    ----------
    adata
        AnnData
    color
        meta information to be shown
    color_map
        mapping from label to color; labels not found in it are drawn in gray
    groupby
        condition which is based-on to separate
    groups
        specific groups to be shown, default is every category of `groupby`
    cond2
        another targeted condition
    v2
        another targeted values of another condition
    save
        file name suffix; each figure is saved with suffix ``_<group><save>``
    legend_loc, legend_fontsize, legend_fontweight
        passed through to `sc.pl.embedding`
    sep
        separator between `v2` and the group name in the plot title
    basis
        embeddings used to visualize, default is X_umap for UMAP
    size
        upper bound of the dot size on the embedding; for each group the
        size actually used is ``min(size, 120000 / n_selected_cells)``
    show
        show the figure; ignored (treated as False) when `save` is given
    """
    import warnings

    if groups is None:
        groups = adata.obs[groupby].cat.categories

    values = adata.obs[color]
    for b in groups:
        # Plain boolean arrays avoid any index alignment between Series.
        selected = np.asarray(adata.obs[groupby] == b)
        if cond2 is not None:
            selected = selected & np.asarray(adata.obs[cond2] == v2)

        n_selected = int(selected.sum())
        if n_selected == 0:
            # Nothing to highlight; also avoids dividing by zero below.
            warnings.warn('embedding: no cells selected for group %r, skipped' % (b,))
            continue

        # Labels of the selected cells, '' for everything else (background).
        labels = values.astype(str).where(selected, '')
        # Compare as strings because the temporary column holds strings.
        shown = sorted(values[selected].dropna().astype(str).unique())
        # Computed per group so one large group does not shrink the dots of
        # every following group.
        point_size = min(size, 120000 / n_selected)

        adata.obs['tmp'] = labels.astype('category')
        try:
            if color_map is not None:
                palette = [color_map[i] if i in color_map else 'gray'
                           for i in adata.obs['tmp'].cat.categories]
            else:
                palette = None

            title = str(b) if cond2 is None else '{}{}{}'.format(v2, sep, b)
            if save is not None:
                save_ = '_{}{}'.format(b, save)
                show_ = False
            else:
                save_ = None
                show_ = show

            sc.pl.embedding(adata, color='tmp', basis=basis, groups=shown, title=title,
                            palette=palette, size=point_size, save=save_,
                            legend_loc=legend_loc, legend_fontsize=legend_fontsize,
                            legend_fontweight=legend_fontweight, show=show_)
        finally:
            # Always remove the temporary annotations, even if plotting failed.
            del adata.obs['tmp']
            if 'tmp_colors' in adata.uns:
                del adata.uns['tmp_colors']
        

def plot_meta(
        adata,
        use_rep=None,
        color='celltype',
        batch='batch',
        colors=None,
        cmap='Blues',
        vmax=1,
        vmin=0,
        mask=True,
        annot=False,
        save=None,
        fontsize=8
    ):
    """
    Plot meta correlations among batches

    The mean representation of every (batch, `color` category) pair that
    contains at least one cell is computed, and the correlation matrix of
    these means is drawn as a heatmap. Tick labels are colored by batch.
    ``adata.obs[color]`` is converted to a categorical column in place.

    Parameters
    ----------
    adata
        AnnData
    use_rep
        the cell representations or embeddings used to calculate the correlations;
        looked up in `adata.obsm` first, then `adata.layers`. Default is None,
        in which case (or when the key is not found) `adata.X` is used
    color
        the meta information whose categories are averaged, default is celltype
    batch
        the meta information based-on, default is batch
    colors
        colors for each batch; reused cyclically when there are more batches than colors
    cmap
        color map for information to be shown
    vmax
        max value
    vmin
        min value
    mask
        if true, hide the upper triangle (above the diagonal) of the matrix
    annot
        show specific values
    save
        save the figure to this path instead of showing it
    fontsize
        font size
    """
    if colors is None:
        colors = ['#FFFF00', '#1CE6FF', '#FF34FF', '#FF4A46', '#008941', '#006FA6', '#A30059', '#FFDBE5', '#7A4900', '#0000A6',
                  '#63FFAC', '#B79762', '#004D43', '#8FB0FF', '#997D87', '#5A0007', '#809693', '#6A3A4C', '#1B4400', '#4FC601',
                  '#3B5DFF', '#4A3B53', '#FF2F80', '#61615A', '#BA0900', '#6B7900', '#00C2A0', '#FFAA92', '#FF90C9', '#B903AA',
                  '#D16100', '#DDEFFF', '#000035', '#7B4F4B', '#A1C299', '#300018', '#0AA6D8', '#013349', '#00846F', '#372101',
                  '#FFB500', '#C2FFED', '#A079BF', '#CC0744', '#C0B9B2', '#C2FF99', '#001E09']

    adata.obs[color] = adata.obs[color].astype('category')

    # Pick the data matrix once; it does not change inside the loops.
    if use_rep and use_rep in adata.obsm:
        data = adata.obsm[use_rep]
    elif use_rep and use_rep in adata.layers:
        data = adata.layers[use_rep]
    else:
        data = adata.X

    meta = []
    name = []
    # Kept separate from the `color` parameter, which names the obs column.
    tick_colors = []
    batches = np.unique(adata.obs[batch])
    for i, b in enumerate(batches):
        in_batch = np.asarray(adata.obs[batch] == b)
        for cat in adata.obs[color].cat.categories:
            index = np.where(np.asarray(adata.obs[color] == cat) & in_batch)[0]
            if len(index) == 0:
                continue
            # np.asarray(...).ravel() flattens the 2-D np.matrix that sparse
            # matrices return from .mean(0) into a plain 1-D vector.
            meta.append(np.asarray(data[index].mean(0)).ravel())
            name.append(cat)
            tick_colors.append(colors[i % len(colors)])

    meta = np.stack(meta)
    plt.figure(figsize=(10, 10))
    corr = np.corrcoef(meta)
    if mask:
        heat_mask = np.zeros_like(corr, dtype=bool)
        heat_mask[np.triu_indices_from(heat_mask, k=1)] = True
    else:
        heat_mask = None
    grid = sns.heatmap(corr, mask=heat_mask, xticklabels=name, yticklabels=name, annot=annot,
                       cmap=cmap, square=True, cbar=True, vmin=vmin, vmax=vmax)
    for tick, c in zip(grid.get_xticklabels(), tick_colors):
        tick.set_color(c)
    for tick, c in zip(grid.get_yticklabels(), tick_colors):
        tick.set_color(c)
    plt.xticks(rotation=45, horizontalalignment='right', fontsize=fontsize)
    plt.yticks(fontsize=fontsize)

    if save:
        plt.savefig(save, bbox_inches='tight')
    else:
        plt.show()
        

def plot_meta2(
        adata,
        use_rep='latent',
        color='celltype',
        batch='batch',
        color_map=None,
        figsize=(10, 10),
        cmap='Blues',
        batches=None,
        annot=False,
        save=None,
        cbar=True,
        keep=False,
        fontsize=8,
        vmin=0,
        vmax=1
    ):
    """
    Plot meta correlations between two batches

    The mean representation of every `color` category present in each of the
    two batches is computed; the heatmap shows the correlations of the
    categories of the second batch (rows) against those of the first batch
    (columns). ``adata.obs[color]`` is converted to a categorical column in place.

    Parameters
    ----------
    adata
        AnnData
    use_rep
        the cell representations or embeddings used to calculate the correlations;
        looked up in `adata.obsm` first, then `adata.layers`, otherwise `adata.X`
        is used. Default is `latent`
    color
        the meta information whose categories are averaged, default is celltype
    batch
        the meta information based-on, default is batch
    color_map
        mapping from tick label text to color; labels not in it keep their default color
    figsize
        figure size
    cmap
        color map for information to be shown
    batches
        the two batches to compare, default is the first two unique values of
        `adata.obs[batch]`; only the first two entries are used
    annot
        show specific values
    save
        save the figure to this path instead of showing it
    cbar
        draw the color bar
    keep
        if true, show every category of `color` on both axes and leave
        combinations that are absent from a batch at zero
    fontsize
        font size
    vmin
        min value
    vmax
        max value

    Raises
    ------
    ValueError
        if fewer than two batches are available
    """
    adata.obs[color] = adata.obs[color].astype('category')
    if batches is None:
        batches = np.unique(adata.obs[batch])
    if len(batches) < 2:
        raise ValueError('plot_meta2 needs two batches, got %d' % len(batches))
    # Only two batches can be compared; extra ones would add unlabeled rows.
    batches = batches[:2]

    # Pick the data matrix once; it does not change inside the loops.
    if use_rep and use_rep in adata.obsm:
        data = adata.obsm[use_rep]
    elif use_rep and use_rep in adata.layers:
        data = adata.layers[use_rep]
    else:
        data = adata.X

    categories = adata.obs[color].cat.categories
    meta = []
    labels = []  # categories actually present in each of the two batches
    for b in batches:
        in_batch = np.asarray(adata.obs[batch] == b)
        present = []
        for cat in categories:
            index = np.where(np.asarray(adata.obs[color] == cat) & in_batch)[0]
            if len(index) > 0:
                # Flatten the 2-D np.matrix sparse matrices return from .mean(0).
                meta.append(np.asarray(data[index].mean(0)).ravel())
                present.append(cat)
        labels.append(present)
    xticklabels, yticklabels = labels

    meta = np.stack(meta)

    plt.figure(figsize=figsize)
    corr = np.corrcoef(meta)

    # Rows of `meta` are ordered batch 0 first, then batch 1; keep only the
    # block of batch-1 rows against batch-0 columns.
    n_x = len(xticklabels)
    corr = corr[n_x:, :n_x]
    if keep:
        corr_ = np.zeros((len(categories), len(categories)))
        x_ind = [i for i, k in enumerate(categories) if k in xticklabels]
        y_ind = [i for i, k in enumerate(categories) if k in yticklabels]
        corr_[np.ix_(y_ind, x_ind)] = corr
        corr = corr_
        xticklabels, yticklabels = categories, categories
    grid = sns.heatmap(corr, xticklabels=xticklabels, yticklabels=yticklabels, annot=annot,
                       cmap=cmap, square=True, cbar=cbar, vmin=vmin, vmax=vmax)

    if color_map is not None:
        for tick in list(grid.get_xticklabels()) + list(grid.get_yticklabels()):
            text = tick.get_text()
            if text in color_map:
                tick.set_color(color_map[text])
    plt.xticks(rotation=45, horizontalalignment='right', fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    plt.xlabel(batches[0], fontsize=fontsize)
    plt.ylabel(batches[1], fontsize=fontsize)

    if save:
        plt.savefig(save, bbox_inches='tight')
    else:
        plt.show()
        
        
        
from sklearn.metrics import confusion_matrix
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, f1_score

def reassign_cluster_with_ref(Y_pred, Y):
    """
    Reassign cluster to reference labels

    A contingency table between predicted and reference labels is built and
    the one-to-one mapping of predicted labels onto reference labels that
    maximises the total overlap is found with `scipy.optimize.linear_sum_assignment`.

    Parameters
    ----------
    Y_pred: predict y classes, non-negative integer labels
    Y: true y classes, non-negative integer labels, same size as Y_pred

    Returns
    -------
    y_pred: reassignment index predict y classes (same shape and dtype as Y_pred)
    indices: classes assignment as the tuple ``(row_ind, col_ind)`` returned by
        `linear_sum_assignment`; predicted label ``row_ind[k]`` is mapped to
        reference label ``col_ind[k]``
    """
    from scipy.optimize import linear_sum_assignment

    Y_pred = np.asarray(Y_pred)
    Y = np.asarray(Y)
    assert Y_pred.size == Y.size, 'Y_pred and Y must have the same size'

    D = max(Y_pred.max(), Y.max()) + 1
    # w[i, j] counts the samples predicted as i whose reference label is j.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (Y_pred, Y), 1)

    # linear_sum_assignment minimises cost, so turn overlap into a cost.
    # It returns two index arrays (rows, columns), not an array of pairs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)

    # The cost matrix is square, so every label in [0, D) gets a target.
    mapping = np.zeros(D, dtype=Y_pred.dtype)
    mapping[row_ind] = col_ind

    return mapping[Y_pred], (row_ind, col_ind)


def plot_confusion(y, y_pred, save=None, cmap='Blues'):
    """
    Plot confusion matrix

    The confusion matrix of `y` against `y_pred` is normalised so that every
    column sums to one (columns without any sample stay zero) and is drawn as
    a heatmap. Both axes use the sorted union of the labels found in `y` and
    `y_pred`, which is the label order the matrix is computed with.

    Parameters
    ----------
    y
        ground truth labels
    y_pred
        predicted labels
    save
        save the figure to this path instead of showing it
    cmap
        color map

    Return
    ------
    F1 score (micro-averaged)
    NMI score
    ARI score
    """
    # One shared label list, passed explicitly so that the tick labels are
    # guaranteed to line up with the rows and columns of the matrix.
    labels = np.union1d(y, y_pred)

    cm = confusion_matrix(y, y_pred, labels=labels)
    f1 = f1_score(y, y_pred, average='micro')
    nmi = normalized_mutual_info_score(y, y_pred)
    ari = adjusted_rand_score(y, y_pred)

    # Column-wise normalisation; skip empty columns instead of dividing by zero.
    col_sums = cm.sum(axis=0, keepdims=True)
    cm = np.divide(cm, col_sums, out=np.zeros(cm.shape, dtype=float), where=col_sums != 0)

    plt.figure(figsize=(14, 14))
    sns.heatmap(cm, xticklabels=labels, yticklabels=labels,
                cmap=cmap, square=True, cbar=False, vmin=0, vmax=1)

    plt.xticks(rotation=45, horizontalalignment='right')
    plt.yticks(fontsize=14, rotation=0)
    # NOTE(review): rows of sklearn's confusion matrix correspond to `y`, so
    # this axis title is only accurate if `y` holds the Leiden clusters - confirm.
    plt.ylabel('Leiden cluster', fontsize=18)

    if save:
        plt.savefig(save, bbox_inches='tight')
    else:
        plt.show()

    return f1, nmi, ari
